{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from sklearn.utils import shuffle\n",
    "import re\n",
    "import time\n",
    "import collections\n",
    "import os\n",
    "import itertools\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_dataset(words, n_words, atleast=1):\n",
    "    count = [['PAD', 0], ['GO', 1], ['EOS', 2], ['UNK', 3]]\n",
    "    counter = collections.Counter(words).most_common(n_words)\n",
    "    counter = [i for i in counter if i[1] >= atleast]\n",
    "    count.extend(counter)\n",
    "    dictionary = dict()\n",
    "    for word, _ in count:\n",
    "        dictionary[word] = len(dictionary)\n",
    "    data = list()\n",
    "    unk_count = 0\n",
    "    for word in words:\n",
    "        index = dictionary.get(word, 0)\n",
    "        if index == 0:\n",
    "            unk_count += 1\n",
    "        data.append(index)\n",
    "    count[0][1] = unk_count\n",
    "    reversed_dictionary = dict(zip(dictionary.values(), dictionary.keys()))\n",
    "    return data, count, dictionary, reversed_dictionary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Records are presumably [noisy, normalized] string pairs - confirm against the data file.\n",
    "with open('augment-normalizer-v4.json') as fopen:\n",
    "    texts = json.load(fopen)\n",
    "\n",
    "# Keep only records that have both sides and a non-empty first side,\n",
    "# then split each side into a list of characters.\n",
    "usable_pairs = [pair for pair in texts if len(pair) >= 2 and len(pair[0])]\n",
    "before = [list(pair[0]) for pair in usable_pairs]\n",
    "after = [list(pair[1]) for pair in usable_pairs]\n",
    "\n",
    "assert len(before) == len(after)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vocab from size: 28\n",
      "Most common words [('a', 1090958), ('l', 943383), ('e', 773153), ('n', 623036), ('r', 499905), ('x', 439435)]\n",
      "Sample data [4, 19, 4, 20, 9, 19, 4, 20, 9, 19] ['a', 'b', 'a', 'd', 'x', 'b', 'a', 'd', 'x', 'b']\n",
      "filtered vocab size: 32\n",
      "% of vocab used: 114.29%\n"
     ]
    }
   ],
   "source": [
    "# Flatten the source-side character lists into one stream and build its vocabulary.\n",
    "concat_from = list(itertools.chain(*before))\n",
    "vocabulary_size_from = len(set(concat_from))\n",
    "data_from, count_from, dictionary_from, rev_dictionary_from = build_dataset(concat_from, vocabulary_size_from)\n",
    "print('vocab from size: %d'%(vocabulary_size_from))\n",
    "print('Most common words', count_from[4:10])\n",
    "print('Sample data', data_from[:10], [rev_dictionary_from[i] for i in data_from[:10]])\n",
    "print('filtered vocab size:',len(dictionary_from))\n",
    "# Scale to percent before rounding: rounding first and multiplying afterwards\n",
    "# can print float noise such as 113.78999999999999. The value exceeds 100\n",
    "# because the dictionary also contains the four special tokens.\n",
    "print(\"% of vocab used: {}%\".format(round(100 * len(dictionary_from) / vocabulary_size_from, 2)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vocab from size: 29\n",
      "Most common words [('a', 2164890), (' ', 1131495), ('l', 943383), ('k', 843343), ('h', 828089), ('t', 729459)]\n",
      "Sample data [4, 19, 4, 21, 9, 4, 7, 5, 4, 19] ['a', 'b', 'a', 'd', 't', 'a', 'k', ' ', 'a', 'b']\n",
      "filtered vocab size: 33\n",
      "% of vocab used: 113.78999999999999%\n"
     ]
    }
   ],
   "source": [
    "# Flatten the target-side character lists into one stream and build its vocabulary.\n",
    "concat_to = list(itertools.chain(*after))\n",
    "vocabulary_size_to = len(set(concat_to))\n",
    "data_to, count_to, dictionary_to, rev_dictionary_to = build_dataset(concat_to, vocabulary_size_to)\n",
    "# This cell reports the target vocabulary, so label it 'to' rather than 'from'.\n",
    "print('vocab to size: %d'%(vocabulary_size_to))\n",
    "print('Most common words', count_to[4:10])\n",
    "print('Sample data', data_to[:10], [rev_dictionary_to[i] for i in data_to[:10]])\n",
    "print('filtered vocab size:',len(dictionary_to))\n",
    "# Scale to percent before rounding: rounding first and multiplying afterwards\n",
    "# printed 113.78999999999999 here. The value exceeds 100 because the\n",
    "# dictionary also contains the four special tokens.\n",
    "print(\"% of vocab used: {}%\".format(round(100 * len(dictionary_to) / vocabulary_size_to, 2)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ids of the reserved special tokens, read from the source vocabulary.\n",
    "GO, PAD, EOS, UNK = (dictionary_from[token] for token in ('GO', 'PAD', 'EOS', 'UNK'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Terminate every target sequence with the EOS token. The guard keeps the\n",
    "# cell idempotent: running it again must not append a second EOS.\n",
    "for target in after:\n",
    "    if not target or target[-1] != 'EOS':\n",
    "        target.append('EOS')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Stemmer:\n",
    "    \"\"\"Character-level seq2seq model: LSTM encoder plus Luong-attention LSTM decoder.\n",
    "\n",
    "    The constructor adds a training branch (scheduled sampling) and a\n",
    "    beam-search inference branch to the default graph; both branches share\n",
    "    their decoder weights through the 'decode' variable scope.\n",
    "\n",
    "    The graph relies on the module-level GO and EOS token ids and on id 0\n",
    "    being padding, because sequence lengths are computed by counting the\n",
    "    non-zero ids of each row.\n",
    "\n",
    "    Attributes:\n",
    "        X, Y: int32 placeholders holding padded token ids, one row per sequence.\n",
    "        X_seq_len, Y_seq_len: number of non-zero ids in each row.\n",
    "        encoder_out, encoder_state: outputs and final state of the encoder.\n",
    "        training_logits: logits produced by the training decoder.\n",
    "        predicting_ids: ids of beam 0 from the beam-search decoder (op name 'logits').\n",
    "        cost: sequence loss over the non-padded target positions.\n",
    "        optimizer: Adam update op minimising `cost`.\n",
    "        prediction: argmax ids at the non-padded target positions.\n",
    "        accuracy: fraction of those ids that equal the targets.\n",
    "    \"\"\"\n",
    "    def __init__(self, size_layer, num_layers, embedded_size,\n",
    "                 from_dict_size, to_dict_size, learning_rate,\n",
    "                 dropout = 0.8, beam_width = 15, force_teaching_ratio=0.5):\n",
    "        \"\"\"Build the training and inference graph.\n",
    "\n",
    "        Args:\n",
    "            size_layer: units per LSTM cell; also the attention size.\n",
    "            num_layers: number of stacked LSTM cells in encoder and decoder.\n",
    "            embedded_size: width of both embedding tables.\n",
    "            from_dict_size: number of rows in the encoder embedding table.\n",
    "            to_dict_size: number of rows in the decoder embedding table and\n",
    "                width of the output projection.\n",
    "            learning_rate: learning rate passed to Adam.\n",
    "            dropout: accepted but never used while building the graph.\n",
    "            beam_width: number of beams kept by the inference decoder.\n",
    "            force_teaching_ratio: the training helper samples its own output\n",
    "                with probability `1 - force_teaching_ratio`.\n",
    "        \"\"\"\n",
    "\n",
    "        def lstm_cell(reuse=False):\n",
    "            return tf.nn.rnn_cell.LSTMCell(size_layer, reuse=reuse)\n",
    "\n",
    "        self.X = tf.placeholder(tf.int32, [None, None])\n",
    "        self.X_seq_len = tf.count_nonzero(self.X, 1, dtype=tf.int32)\n",
    "        self.Y = tf.placeholder(tf.int32, [None, None])\n",
    "        self.Y_seq_len = tf.count_nonzero(self.Y, 1, dtype=tf.int32)\n",
    "        batch_size = tf.shape(self.X)[0]\n",
    "\n",
    "        encoder_embeddings = tf.Variable(tf.random_uniform([from_dict_size, embedded_size], -1, 1))\n",
    "        encoder_embedded = tf.nn.embedding_lookup(encoder_embeddings, self.X)\n",
    "        encoder_cells = tf.nn.rnn_cell.MultiRNNCell([lstm_cell() for _ in range(num_layers)])\n",
    "        self.encoder_out, self.encoder_state = tf.nn.dynamic_rnn(cell = encoder_cells,\n",
    "                                                                 inputs = encoder_embedded,\n",
    "                                                                 sequence_length = self.X_seq_len,\n",
    "                                                                 dtype = tf.float32)\n",
    "\n",
    "        # Every decoder layer starts from the final state of the last encoder layer.\n",
    "        encoder_state = tuple(self.encoder_state[-1] for _ in range(num_layers))\n",
    "        decoder_embeddings = tf.Variable(tf.random_uniform([to_dict_size, embedded_size], -1, 1))\n",
    "        dense_layer = tf.layers.Dense(to_dict_size)\n",
    "\n",
    "        with tf.variable_scope('decode'):\n",
    "            # NOTE(review): the attention memory is the embedded input, not\n",
    "            # self.encoder_out, although the inference branch names its copy\n",
    "            # `encoder_out_tiled`. Confirm this is intended; switching it would\n",
    "            # change variable shapes and invalidate existing checkpoints.\n",
    "            attention_mechanism = tf.contrib.seq2seq.LuongAttention(\n",
    "                num_units = size_layer,\n",
    "                memory = encoder_embedded,\n",
    "                memory_sequence_length = self.X_seq_len)\n",
    "            decoder_cell = tf.contrib.seq2seq.AttentionWrapper(\n",
    "                cell = tf.nn.rnn_cell.MultiRNNCell([lstm_cell() for _ in range(num_layers)]),\n",
    "                attention_mechanism = attention_mechanism,\n",
    "                attention_layer_size = size_layer)\n",
    "            # Decoder input: drop the last column of Y and prepend a GO column.\n",
    "            main = tf.strided_slice(self.Y, [0, 0], [batch_size, -1], [1, 1])\n",
    "            decoder_input = tf.concat([tf.fill([batch_size, 1], GO), main], 1)\n",
    "            training_helper = tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper(\n",
    "                inputs = tf.nn.embedding_lookup(decoder_embeddings, decoder_input),\n",
    "                sequence_length = self.Y_seq_len,\n",
    "                embedding = decoder_embeddings,\n",
    "                sampling_probability = 1 - force_teaching_ratio,\n",
    "                time_major = False)\n",
    "            training_decoder = tf.contrib.seq2seq.BasicDecoder(\n",
    "                cell = decoder_cell,\n",
    "                helper = training_helper,\n",
    "                initial_state = decoder_cell.zero_state(batch_size, tf.float32).clone(cell_state=encoder_state),\n",
    "                output_layer = dense_layer)\n",
    "            training_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(\n",
    "                decoder = training_decoder,\n",
    "                impute_finished = True,\n",
    "                maximum_iterations = tf.reduce_max(self.Y_seq_len))\n",
    "\n",
    "        # Second pass over the same scope with reuse=True: the beam-search\n",
    "        # decoder shares the variables created by the training decoder.\n",
    "        with tf.variable_scope('decode', reuse=True):\n",
    "            encoder_out_tiled = tf.contrib.seq2seq.tile_batch(encoder_embedded, beam_width)\n",
    "            encoder_state_tiled = tf.contrib.seq2seq.tile_batch(encoder_state, beam_width)\n",
    "            X_seq_len_tiled = tf.contrib.seq2seq.tile_batch(self.X_seq_len, beam_width)\n",
    "            attention_mechanism = tf.contrib.seq2seq.LuongAttention(\n",
    "                num_units = size_layer,\n",
    "                memory = encoder_out_tiled,\n",
    "                memory_sequence_length = X_seq_len_tiled)\n",
    "            decoder_cell = tf.contrib.seq2seq.AttentionWrapper(\n",
    "                cell = tf.nn.rnn_cell.MultiRNNCell([lstm_cell(reuse=True) for _ in range(num_layers)]),\n",
    "                attention_mechanism = attention_mechanism,\n",
    "                attention_layer_size = size_layer)\n",
    "            predicting_decoder = tf.contrib.seq2seq.BeamSearchDecoder(\n",
    "                cell = decoder_cell,\n",
    "                embedding = decoder_embeddings,\n",
    "                start_tokens = tf.tile(tf.constant([GO], dtype=tf.int32), [batch_size]),\n",
    "                end_token = EOS,\n",
    "                initial_state = decoder_cell.zero_state(batch_size * beam_width, tf.float32).clone(cell_state = encoder_state_tiled),\n",
    "                beam_width = beam_width,\n",
    "                output_layer = dense_layer,\n",
    "                length_penalty_weight = 0.0)\n",
    "            # Decoding is capped at twice the longest input in the batch.\n",
    "            predicting_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(\n",
    "                decoder = predicting_decoder,\n",
    "                impute_finished = False,\n",
    "                maximum_iterations = 2 * tf.reduce_max(self.X_seq_len))\n",
    "\n",
    "        self.training_logits = training_decoder_output.rnn_output\n",
    "        # The op is named 'logits' so it can be looked up by name after freezing.\n",
    "        self.predicting_ids = tf.identity(predicting_decoder_output.predicted_ids[:, :, 0],name=\"logits\")\n",
    "\n",
    "        # Mask out the padded target positions in both loss and accuracy.\n",
    "        masks = tf.sequence_mask(self.Y_seq_len, tf.reduce_max(self.Y_seq_len), dtype=tf.float32)\n",
    "        self.cost = tf.contrib.seq2seq.sequence_loss(logits = self.training_logits,\n",
    "                                                     targets = self.Y,\n",
    "                                                     weights = masks)\n",
    "\n",
    "        self.optimizer = tf.train.AdamOptimizer(learning_rate).minimize(self.cost)\n",
    "        y_t = tf.argmax(self.training_logits,axis=2)\n",
    "        y_t = tf.cast(y_t, tf.int32)\n",
    "        self.prediction = tf.boolean_mask(y_t, masks)\n",
    "        mask_label = tf.boolean_mask(self.Y, masks)\n",
    "        correct_pred = tf.equal(self.prediction, mask_label)\n",
    "        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "size_layer = 256\n",
    "num_layers = 2\n",
    "embedded_size = 128\n",
    "learning_rate = 1e-3\n",
    "batch_size = 128\n",
    "epoch = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gradients_impl.py:112: UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape. This may consume a large amount of memory.\n",
      "  \"Converting sparse IndexedSlices to a dense Tensor of unknown shape. \"\n"
     ]
    }
   ],
   "source": [
    "# Start from an empty default graph, open a session on it, build the model\n",
    "# and initialise all of its variables.\n",
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = Stemmer(\n",
    "    size_layer = size_layer,\n",
    "    num_layers = num_layers,\n",
    "    embedded_size = embedded_size,\n",
    "    from_dict_size = len(dictionary_from),\n",
    "    to_dict_size = len(dictionary_to),\n",
    "    learning_rate = learning_rate,\n",
    ")\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def str_idx(corpus, dic, UNK=3):\n",
    "    X = []\n",
    "    for i in corpus:\n",
    "        ints = []\n",
    "        for k in i:\n",
    "            ints.append(dic.get(k, UNK))\n",
    "        X.append(ints)\n",
    "    return X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode each side of every pair with its own vocabulary.\n",
    "X, Y = str_idx(before, dictionary_from), str_idx(after, dictionary_to)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/dist-packages/sklearn/cross_validation.py:41: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.\n",
      "  \"This module will be removed in 0.20.\", DeprecationWarning)\n"
     ]
    }
   ],
   "source": [
    "# sklearn.cross_validation is deprecated since 0.18 and removed in 0.20;\n",
    "# train_test_split now lives in sklearn.model_selection.\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# Hold out 10% of the encoded pairs for evaluation.\n",
    "train_X, test_X, train_Y, test_Y = train_test_split(X, Y, test_size=0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pad_sentence_batch(sentence_batch, pad_int):\n",
    "    padded_seqs = []\n",
    "    seq_lens = []\n",
    "    max_sentence_len = max([len(sentence) for sentence in sentence_batch])\n",
    "    for sentence in sentence_batch:\n",
    "        padded_seqs.append(sentence + [pad_int] * (max_sentence_len - len(sentence)))\n",
    "        seq_lens.append(len(sentence))\n",
    "    return padded_seqs, seq_lens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 5984/5984 [09:41<00:00, 10.29it/s, accuracy=0.988, cost=0.0326]\n",
      "test minibatch loop: 100%|██████████| 665/665 [00:28<00:00, 23.28it/s, accuracy=0.969, cost=0.0738]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, pass acc: 0.000000, current acc: 0.970278\n",
      "epoch: 0, avg loss: 0.196654, avg accuracy: 0.939013\n",
      "epoch: 0, avg loss test: 0.094238, avg accuracy test: 0.970278\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 5984/5984 [09:38<00:00, 10.35it/s, accuracy=0.986, cost=0.0476]\n",
      "test minibatch loop: 100%|██████████| 665/665 [00:28<00:00, 23.39it/s, accuracy=0.979, cost=0.0654]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1, pass acc: 0.970278, current acc: 0.977899\n",
      "epoch: 1, avg loss: 0.079812, avg accuracy: 0.974141\n",
      "epoch: 1, avg loss test: 0.067784, avg accuracy test: 0.977899\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 5984/5984 [09:37<00:00, 10.78it/s, accuracy=0.987, cost=0.0371]\n",
      "test minibatch loop: 100%|██████████| 665/665 [00:28<00:00, 23.38it/s, accuracy=0.986, cost=0.0392]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 2, pass acc: 0.977899, current acc: 0.983517\n",
      "epoch: 2, avg loss: 0.057337, avg accuracy: 0.980940\n",
      "epoch: 2, avg loss test: 0.049763, avg accuracy test: 0.983517\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 5984/5984 [09:38<00:00, 11.04it/s, accuracy=0.987, cost=0.0374]\n",
      "test minibatch loop: 100%|██████████| 665/665 [00:28<00:00, 23.38it/s, accuracy=0.992, cost=0.0214]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 3, pass acc: 0.983517, current acc: 0.984213\n",
      "epoch: 3, avg loss: 0.045526, avg accuracy: 0.984753\n",
      "epoch: 3, avg loss test: 0.048175, avg accuracy test: 0.984213\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 5984/5984 [09:38<00:00, 10.87it/s, accuracy=0.985, cost=0.0351]\n",
      "test minibatch loop: 100%|██████████| 665/665 [00:28<00:00, 23.31it/s, accuracy=0.988, cost=0.026] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 4, pass acc: 0.984213, current acc: 0.985816\n",
      "epoch: 4, avg loss: 0.038852, avg accuracy: 0.986804\n",
      "epoch: 4, avg loss test: 0.042522, avg accuracy test: 0.985816\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 5984/5984 [09:38<00:00, 10.85it/s, accuracy=0.992, cost=0.022] \n",
      "test minibatch loop: 100%|██████████| 665/665 [00:28<00:00, 23.34it/s, accuracy=0.991, cost=0.0312]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 5, pass acc: 0.985816, current acc: 0.986477\n",
      "epoch: 5, avg loss: 0.035392, avg accuracy: 0.987867\n",
      "epoch: 5, avg loss test: 0.040669, avg accuracy test: 0.986477\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 5984/5984 [09:38<00:00, 10.34it/s, accuracy=0.985, cost=0.0442] \n",
      "test minibatch loop: 100%|██████████| 665/665 [00:28<00:00, 23.36it/s, accuracy=0.98, cost=0.0506] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 6, pass acc: 0.986477, current acc: 0.987883\n",
      "epoch: 6, avg loss: 0.032731, avg accuracy: 0.988687\n",
      "epoch: 6, avg loss test: 0.035074, avg accuracy test: 0.987883\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 5984/5984 [09:37<00:00, 10.54it/s, accuracy=0.997, cost=0.0127] \n",
      "test minibatch loop: 100%|██████████| 665/665 [00:28<00:00, 23.35it/s, accuracy=0.99, cost=0.0258]  \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 7, pass acc: 0.987883, current acc: 0.988393\n",
      "epoch: 7, avg loss: 0.031217, avg accuracy: 0.989066\n",
      "epoch: 7, avg loss test: 0.033957, avg accuracy test: 0.988393\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 5984/5984 [09:37<00:00, 11.08it/s, accuracy=0.985, cost=0.0565] \n",
      "test minibatch loop: 100%|██████████| 665/665 [00:28<00:00, 23.31it/s, accuracy=0.991, cost=0.028]  \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 8, pass acc: 0.988393, current acc: 0.989235\n",
      "epoch: 8, avg loss: 0.029715, avg accuracy: 0.989519\n",
      "epoch: 8, avg loss test: 0.031453, avg accuracy test: 0.989235\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop:   8%|▊         | 508/5984 [00:49<08:45, 10.41it/s, accuracy=0.994, cost=0.0164]IOPub message rate exceeded.\n",
      "The notebook server will temporarily stop sending output\n",
      "to the client in order to avoid crashing it.\n",
      "To change this limit, set the config variable\n",
      "`--NotebookApp.iopub_msg_rate_limit`.\n",
      "\n",
      "Current values:\n",
      "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
      "NotebookApp.rate_limit_window=3.0 (secs)\n",
      "\n",
      "test minibatch loop: 100%|██████████| 665/665 [00:28<00:00, 23.34it/s, accuracy=0.987, cost=0.0427] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 9, avg loss: 0.028808, avg accuracy: 0.989746\n",
      "epoch: 9, avg loss test: 0.033073, avg accuracy test: 0.988883\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop:  32%|███▏      | 1885/5984 [03:01<06:26, 10.60it/s, accuracy=0.991, cost=0.0232] IOPub message rate exceeded.\n",
      "The notebook server will temporarily stop sending output\n",
      "to the client in order to avoid crashing it.\n",
      "To change this limit, set the config variable\n",
      "`--NotebookApp.iopub_msg_rate_limit`.\n",
      "\n",
      "Current values:\n",
      "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
      "NotebookApp.rate_limit_window=3.0 (secs)\n",
      "\n",
      "train minibatch loop:  51%|█████     | 3058/5984 [04:55<04:45, 10.25it/s, accuracy=0.99, cost=0.0261]  IOPub message rate exceeded.\n",
      "The notebook server will temporarily stop sending output\n",
      "to the client in order to avoid crashing it.\n",
      "To change this limit, set the config variable\n",
      "`--NotebookApp.iopub_msg_rate_limit`.\n",
      "\n",
      "Current values:\n",
      "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
      "NotebookApp.rate_limit_window=3.0 (secs)\n",
      "\n",
      "train minibatch loop:  73%|███████▎  | 4380/5984 [07:03<02:42,  9.89it/s, accuracy=0.987, cost=0.0376] IOPub message rate exceeded.\n",
      "The notebook server will temporarily stop sending output\n",
      "to the client in order to avoid crashing it.\n",
      "To change this limit, set the config variable\n",
      "`--NotebookApp.iopub_msg_rate_limit`.\n",
      "\n",
      "Current values:\n",
      "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
      "NotebookApp.rate_limit_window=3.0 (secs)\n",
      "\n",
      "train minibatch loop:  93%|█████████▎| 5567/5984 [08:58<00:41,  9.94it/s, accuracy=0.988, cost=0.0239] IOPub message rate exceeded.\n",
      "The notebook server will temporarily stop sending output\n",
      "to the client in order to avoid crashing it.\n",
      "To change this limit, set the config variable\n",
      "`--NotebookApp.iopub_msg_rate_limit`.\n",
      "\n",
      "Current values:\n",
      "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
      "NotebookApp.rate_limit_window=3.0 (secs)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "EARLY_STOPPING, CURRENT_CHECKPOINT, CURRENT_ACC, EPOCH = 3, 0, 0, 0\n",
    "\n",
    "# Train until the test accuracy has failed to improve for EARLY_STOPPING\n",
    "# consecutive epochs.\n",
    "while True:\n",
    "    if CURRENT_CHECKPOINT == EARLY_STOPPING:\n",
    "        print('break epoch:%d\\n' % (EPOCH))\n",
    "        break\n",
    "    total_loss, total_accuracy, total_loss_test, total_accuracy_test = 0, 0, 0, 0\n",
    "    train_X, train_Y = shuffle(train_X, train_Y)\n",
    "    test_X, test_Y = shuffle(test_X, test_Y)\n",
    "\n",
    "    train_starts = range(0, len(train_X), batch_size)\n",
    "    pbar = tqdm(train_starts, desc='train minibatch loop')\n",
    "    for k in pbar:\n",
    "        # Slicing past the end is clamped, so the last batch may be shorter.\n",
    "        batch_x, _ = pad_sentence_batch(train_X[k: k + batch_size], PAD)\n",
    "        batch_y, _ = pad_sentence_batch(train_Y[k: k + batch_size], PAD)\n",
    "        acc, loss, _ = sess.run([model.accuracy, model.cost, model.optimizer],\n",
    "                                feed_dict={model.X: batch_x,\n",
    "                                           model.Y: batch_y})\n",
    "        total_loss += loss\n",
    "        total_accuracy += acc\n",
    "        # refresh=False lets tqdm redraw on its own schedule; forcing a redraw\n",
    "        # on every batch is what flooded the notebook with IOPub messages.\n",
    "        pbar.set_postfix(cost=loss, accuracy = acc, refresh=False)\n",
    "\n",
    "    test_starts = range(0, len(test_X), batch_size)\n",
    "    pbar = tqdm(test_starts, desc='test minibatch loop')\n",
    "    for k in pbar:\n",
    "        batch_x, _ = pad_sentence_batch(test_X[k: k + batch_size], PAD)\n",
    "        batch_y, _ = pad_sentence_batch(test_Y[k: k + batch_size], PAD)\n",
    "        acc, loss = sess.run([model.accuracy, model.cost],\n",
    "                             feed_dict={model.X: batch_x,\n",
    "                                        model.Y: batch_y})\n",
    "        total_loss_test += loss\n",
    "        total_accuracy_test += acc\n",
    "        pbar.set_postfix(cost=loss, accuracy = acc, refresh=False)\n",
    "\n",
    "    # Average over the number of batches actually run. Dividing by\n",
    "    # len(data) / batch_size is off whenever the last batch is partial.\n",
    "    total_loss /= len(train_starts)\n",
    "    total_accuracy /= len(train_starts)\n",
    "    total_loss_test /= len(test_starts)\n",
    "    total_accuracy_test /= len(test_starts)\n",
    "\n",
    "    if total_accuracy_test > CURRENT_ACC:\n",
    "        print(\n",
    "            'epoch: %d, pass acc: %f, current acc: %f'\n",
    "            % (EPOCH, CURRENT_ACC, total_accuracy_test)\n",
    "        )\n",
    "        CURRENT_ACC = total_accuracy_test\n",
    "        CURRENT_CHECKPOINT = 0\n",
    "    else:\n",
    "        CURRENT_CHECKPOINT += 1\n",
    "\n",
    "    print('epoch: %d, avg loss: %f, avg accuracy: %f'%(EPOCH, total_loss, total_accuracy))\n",
    "    print('epoch: %d, avg loss test: %f, avg accuracy test: %f'%(EPOCH, total_loss_test, total_accuracy_test))\n",
    "    EPOCH += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'beamsearch-luong-normalize/model.ckpt'"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Make sure the checkpoint directory exists so this cell also runs on a\n",
    "# fresh checkout where the directory has not been created yet.\n",
    "os.makedirs('beamsearch-luong-normalize', exist_ok=True)\n",
    "# Only the trainable variables are saved (no optimizer slots).\n",
    "saver = tf.train.Saver(tf.trainable_variables())\n",
    "saver.save(sess, \"beamsearch-luong-normalize/model.ckpt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Node-name fragments that select a node for export, and fragments that\n",
    "# exclude it again (presumably optimizer bookkeeping nodes - verify).\n",
    "wanted_name_markers = ('Placeholder', 'logits', 'alphas')\n",
    "excluded_name_markers = ('Adam', 'beta', 'OptimizeLoss', 'Global_Step')\n",
    "\n",
    "export_names = []\n",
    "for node in tf.get_default_graph().as_graph_def().node:\n",
    "    selected = 'Variable' in node.op or any(marker in node.name for marker in wanted_name_markers)\n",
    "    if selected and not any(marker in node.name for marker in excluded_name_markers):\n",
    "        export_names.append(node.name)\n",
    "\n",
    "# Comma-separated list of output node names, as expected by freeze_graph.\n",
    "strings = ','.join(export_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def freeze_graph(model_dir, output_node_names):\n",
    "    \"\"\"Freeze the latest checkpoint found in `model_dir` into `frozen_model.pb`.\n",
    "\n",
    "    The checkpoint's meta graph is imported into a fresh graph, its variables\n",
    "    are restored and converted to constants, and the result is written next\n",
    "    to the checkpoint file.\n",
    "\n",
    "    Args:\n",
    "        model_dir: directory holding the checkpoint state and the .meta file.\n",
    "        output_node_names: comma-separated names of the nodes to keep.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if `model_dir` does not exist or contains no checkpoint.\n",
    "    \"\"\"\n",
    "    if not tf.gfile.Exists(model_dir):\n",
    "        raise AssertionError(\n",
    "            \"Export directory doesn't exists. Please specify an export \"\n",
    "            \"directory: %s\" % model_dir)\n",
    "\n",
    "    checkpoint = tf.train.get_checkpoint_state(model_dir)\n",
    "    # get_checkpoint_state returns None when no checkpoint state is available;\n",
    "    # fail with a clear message instead of an AttributeError further down.\n",
    "    if checkpoint is None or not checkpoint.model_checkpoint_path:\n",
    "        raise AssertionError('No checkpoint found in directory: %s' % model_dir)\n",
    "    input_checkpoint = checkpoint.model_checkpoint_path\n",
    "\n",
    "    absolute_model_dir = \"/\".join(input_checkpoint.split('/')[:-1])\n",
    "    output_graph = absolute_model_dir + \"/frozen_model.pb\"\n",
    "    clear_devices = True\n",
    "    # Work in a private graph so the notebook's default graph is left untouched.\n",
    "    with tf.Session(graph=tf.Graph()) as sess:\n",
    "        saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices)\n",
    "        saver.restore(sess, input_checkpoint)\n",
    "        output_graph_def = tf.graph_util.convert_variables_to_constants(\n",
    "            sess,\n",
    "            tf.get_default_graph().as_graph_def(),\n",
    "            output_node_names.split(\",\")\n",
    "        )\n",
    "        with tf.gfile.GFile(output_graph, \"wb\") as f:\n",
    "            f.write(output_graph_def.SerializeToString())\n",
    "        print(\"%d ops in the final graph.\" % len(output_graph_def.node))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from beamsearch-luong-normalize/model.ckpt\n",
      "INFO:tensorflow:Froze 14 variables.\n",
      "INFO:tensorflow:Converted 14 variables to const ops.\n",
      "1739 ops in the final graph.\n"
     ]
    }
   ],
   "source": [
    "# Freeze the saved checkpoint, keeping only the nodes collected in `strings`.\n",
    "freeze_graph(model_dir=\"beamsearch-luong-normalize\", output_node_names=strings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_graph(frozen_graph_filename):\n",
    "    \"\"\"Read a serialized GraphDef from `frozen_graph_filename` into a new tf.Graph.\n",
    "\n",
    "    The graph def is imported without an explicit name, so callers in this\n",
    "    notebook look its tensors up under the 'import/' prefix.\n",
    "    \"\"\"\n",
    "    graph_def = tf.GraphDef()\n",
    "    with tf.gfile.GFile(frozen_graph_filename, \"rb\") as f:\n",
    "        serialized = f.read()\n",
    "    graph_def.ParseFromString(serialized)\n",
    "\n",
    "    graph = tf.Graph()\n",
    "    with graph.as_default():\n",
    "        tf.import_graph_def(graph_def)\n",
    "    return graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the frozen model back into its own graph for inference.\n",
    "g = load_graph(frozen_graph_filename='beamsearch-luong-normalize/frozen_model.pb')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PREDICTED AFTER: makin\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py:1702: UserWarning: An interactive session is already active. This can cause out-of-memory errors in some cases. You must explicitly call `InteractiveSession.close()` to release resources held by the other session(s).\n",
      "  warnings.warn('An interactive session is already active. This can '\n"
     ]
    }
   ],
   "source": [
    "# Look up the input placeholder and the beam-search output in the frozen\n",
    "# graph, then open a session bound to that graph.\n",
    "x = g.get_tensor_by_name('import/Placeholder:0')\n",
    "logits = g.get_tensor_by_name('import/logits:0')\n",
    "test_sess = tf.InteractiveSession(graph=g)\n",
    "\n",
    "encoded_input = str_idx(['makn'], dictionary_from)\n",
    "predicted = test_sess.run(logits, feed_dict={x: encoded_input})[0]\n",
    "# Ids 0-3 are the special tokens and are dropped from the decoded text.\n",
    "decoded_chars = [rev_dictionary_to[n] for n in predicted if n not in [0, 1, 2, 3]]\n",
    "print('PREDICTED AFTER:', ''.join(decoded_chars))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PREDICTED AFTER: kecomelkan\n"
     ]
    }
   ],
   "source": [
    "# Reuse `x`, `logits` and `test_sess` from the first prediction cell: opening\n",
    "# another InteractiveSession here would leave the previous one unclosed.\n",
    "predicted = test_sess.run(logits,feed_dict={x:str_idx(['kecomelkn'],dictionary_from)})[0]\n",
    "# Ids 0-3 are the special tokens and are dropped from the decoded text.\n",
    "print('PREDICTED AFTER:',''.join([rev_dictionary_to[n] for n in predicted if n not in[0,1,2,3]]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PREDICTED AFTER: tak jadi\n"
     ]
    }
   ],
   "source": [
    "encoded_input = str_idx(['xjdi'], dictionary_from)\n",
    "predicted = test_sess.run(logits, feed_dict={x: encoded_input})[0]\n",
    "# Ids 0-3 are the special tokens and are dropped from the decoded text.\n",
    "decoded_chars = [rev_dictionary_to[n] for n in predicted if n not in [0, 1, 2, 3]]\n",
    "print('PREDICTED AFTER:', ''.join(decoded_chars))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Persist all four vocabularies next to the frozen model.\n",
    "# NOTE: JSON object keys are always strings, so the integer ids used as keys\n",
    "# in the two reversed dictionaries are written out as strings.\n",
    "vocab_payload = {\n",
    "    'dictionary_from': dictionary_from,\n",
    "    'dictionary_to': dictionary_to,\n",
    "    'rev_dictionary_to': rev_dictionary_to,\n",
    "    'rev_dictionary_from': rev_dictionary_from,\n",
    "}\n",
    "with open('beamsearch-luong-normalize.json','w') as fopen:\n",
    "    fopen.write(json.dumps(vocab_payload))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
